import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as m
from layer_afa import AFA_layer_cam_data, AFA_layer_sam_data, AFA_layer_cam, AFA_layer_sam


class Bottleneck(nn.Module):
    """Two-stream bottleneck residual block.

    The block owns two independent copies (attribute prefixes ``A_`` and
    ``B_``) of a 1x1 -> 3x3 -> 1x1 bottleneck with a residual shortcut.
    ``forward`` takes a pair ``(x_a, x_b)`` and returns a pair, sending each
    input through its own stream. When ``planes == 512`` the two stream
    outputs are additionally passed jointly through ``afa_cam`` and then
    ``afa_sam`` (AFA layers imported from ``layer_afa``).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, A_downsample=None, B_downsample=None, BatchNorm=None):
        super().__init__()
        self.planes = planes
        self.dilation = dilation
        wide = planes * 4  # channels after the expanding 1x1 conv

        # Stream A: reduce (1x1) -> spatial (3x3, strided/dilated) -> expand (1x1).
        self.A_conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.A_bn1 = BatchNorm(planes)
        self.A_conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=dilation, dilation=dilation, bias=False)
        self.A_bn2 = BatchNorm(planes)
        self.A_conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.A_bn3 = BatchNorm(wide)
        self.A_relu = nn.ReLU(inplace=True)

        # Stream B: same topology as stream A, separate weights.
        self.B_conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.B_bn1 = BatchNorm(planes)
        self.B_conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=dilation, dilation=dilation, bias=False)
        self.B_bn2 = BatchNorm(planes)
        self.B_conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.B_bn3 = BatchNorm(wide)
        self.B_relu = nn.ReLU(inplace=True)

        # Optional projection shortcuts (None -> identity shortcut).
        self.A_downsample = A_downsample
        self.B_downsample = B_downsample
        self.stride = stride

        if planes == 512:
            self.afa_cam = AFA_layer_cam_data(channels=512)
            # NOTE(review): h/w are hard-coded; presumably they must match the
            # spatial size of the feature map reaching this block (depends on
            # input resolution and output stride) — adjust when those change.
            self.afa_sam = AFA_layer_sam_data(h=8, w=16)

    @staticmethod
    def _stream(inp, conv1, bn1, conv2, bn2, conv3, bn3, relu, downsample):
        """Run one bottleneck stream on ``inp`` and add its residual shortcut."""
        y = relu(bn1(conv1(inp)))
        y = relu(bn2(conv2(y)))
        y = bn3(conv3(y))
        shortcut = inp if downsample is None else downsample(inp)
        return relu(y + shortcut)

    def forward(self, x):
        """Process the pair ``x = (x_a, x_b)`` and return ``(out_a, out_b)``."""
        feat_a, feat_b = x[0], x[1]

        out_a = self._stream(feat_a, self.A_conv1, self.A_bn1, self.A_conv2, self.A_bn2,
                             self.A_conv3, self.A_bn3, self.A_relu, self.A_downsample)
        out_b = self._stream(feat_b, self.B_conv1, self.B_bn1, self.B_conv2, self.B_bn2,
                             self.B_conv3, self.B_bn3, self.B_relu, self.B_downsample)

        if self.planes == 512:
            out_a, out_b = self.afa_cam(out_a, out_b)
            out_a, out_b = self.afa_sam(out_a, out_b)

        return out_a, out_b


class ResNet(nn.Module):
    """Two-stream ResNet backbone (streams ``A_`` and ``B_``) for dense prediction.

    Each stream has its own stem (7x7 conv, BN, ReLU, max-pool) and both
    stems see the same input image. The four residual stages are
    ``nn.Sequential`` chains of ``block`` instances, each of which carries
    both streams and maps a pair of tensors to a pair of tensors. layer4 is a
    multi-grid unit with dilation multipliers ``[1, 2, 4]``.

    Args:
        block: residual block class taking/returning a pair of tensors; must
            define ``expansion``.
        layers: number of blocks in layer1..layer3 (``layers[3]`` is not used;
            layer4 always has three blocks).
        BatchNorm: normalization layer constructor.
        output_stride: 16 or 8; overall downsampling factor of layer4's output.
        pretrained: if True, copy torchvision ResNet-50 weights into both streams.

    Raises:
        NotImplementedError: for any other ``output_stride``.
    """

    def __init__(self, block, layers, BatchNorm, output_stride=16, pretrained=True):
        super(ResNet, self).__init__()
        self.inplanes = 64
        blocks = [1, 2, 4]  # multi-grid dilation multipliers for layer4
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            # NOTE(review): the Bottleneck block in this file sizes its AFA
            # spatial layer for a fixed feature map (h=8, w=16). OS8 doubles
            # layer4's resolution, so that size presumably needs changing too.
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            # Previously fell through and failed with an UnboundLocalError.
            raise NotImplementedError(
                'Unsupported output_stride {}; expected 16 or 8'.format(output_stride))

        # Stems (one per stream).
        self.A_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.A_bn1 = BatchNorm(64)
        self.A_relu = nn.ReLU(inplace=True)
        self.A_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.B_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.B_bn1 = BatchNorm(64)
        self.B_relu = nn.ReLU(inplace=True)
        self.B_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages (each block holds both streams).
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_downsample_pair(self, planes, expansion, stride, BatchNorm):
        """Return (A, B) 1x1 projection shortcuts, or (None, None) if identity fits."""
        out_planes = planes * expansion
        if stride == 1 and self.inplanes == out_planes:
            return None, None

        def projection():
            return nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                BatchNorm(out_planes),
            )

        # Tuple elements are evaluated left to right: stream A is built first.
        return projection(), projection()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a stage of ``blocks`` blocks; only the first one strides/projects."""
        A_downsample, B_downsample = self._make_downsample_pair(planes, block.expansion, stride, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation,
                        A_downsample=A_downsample, B_downsample=B_downsample, BatchNorm=BatchNorm)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a multi-grid stage: block ``i`` uses dilation ``blocks[i] * dilation``."""
        A_downsample, B_downsample = self._make_downsample_pair(planes, block.expansion, stride, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
                        A_downsample=A_downsample, B_downsample=B_downsample, BatchNorm=BatchNorm)]
        self.inplanes = planes * block.expansion
        for multiplier in blocks[1:]:
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=multiplier * dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def forward(self, input):
        """Run both streams on ``input``; return ``(x_a, x_b, low_a, low_b)``.

        ``x_a``/``x_b`` are the layer4 outputs; ``low_a``/``low_b`` are the
        layer1 outputs (returned for use as low-level features).
        """
        x1 = self.A_maxpool(self.A_relu(self.A_bn1(self.A_conv1(input))))
        x2 = self.B_maxpool(self.B_relu(self.B_bn1(self.B_conv1(input))))

        x1, x2 = self.layer1((x1, x2))
        low_level_feat1, low_level_feat2 = x1, x2
        x1, x2 = self.layer2((x1, x2))
        x1, x2 = self.layer3((x1, x2))
        x1, x2 = self.layer4((x1, x2))

        return x1, x2, low_level_feat1, low_level_feat2

    def _init_weight(self):
        """He-style normal init for convs; unit weight / zero bias for BN layers."""
        # Loop variable is not named `m` to avoid shadowing the torchvision alias.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                n = mod.kernel_size[0] * mod.kernel_size[1] * mod.out_channels
                mod.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(mod, nn.BatchNorm2d):
                mod.weight.data.fill_(1)
                mod.bias.data.zero_()
            elif isinstance(mod, nn.Linear):
                torch.nn.init.kaiming_normal_(mod.weight)
            elif isinstance(mod, nn.BatchNorm1d):
                mod.weight.data.fill_(1)
                mod.bias.data.zero_()

    def _load_pretrained_model(self):
        """Copy torchvision ResNet-50 weights into both streams.

        Stem keys such as ``conv1.weight`` map to ``A_conv1.weight`` and
        ``B_conv1.weight``. Stage keys such as ``layer1.0.conv1.weight`` map to
        ``layer1.0.A_conv1.weight`` and ``layer1.0.B_conv1.weight``. Keys with
        no counterpart (e.g. ``fc.*``) are skipped. Entries absent from the
        checkpoint (e.g. the AFA layers) keep their current values.
        """
        # NOTE: newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
        resnet50 = m.resnet50(pretrained=True)
        pretrain_dict = resnet50.state_dict()
        state_dict = self.state_dict()
        model_dict = {}
        for k, v in pretrain_dict.items():
            candidates = ['A_' + k, 'B_' + k]
            if k.startswith('layer'):
                # Split 'layerX.<idx>.<rest>' on the first two dots so that
                # multi-digit block indices are mapped correctly as well.
                parts = k.split('.', 2)
                if len(parts) == 3:
                    prefix = parts[0] + '.' + parts[1] + '.'
                    candidates.append(prefix + 'A_' + parts[2])
                    candidates.append(prefix + 'B_' + parts[2])
            for name in candidates:
                if name in state_dict:
                    model_dict[name] = v

        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def ResNet_backbone(output_stride, BatchNorm, pretrained=True):
    """Build the two-stream backbone with the ResNet-50 stage layout [3, 4, 6, 3]."""
    resnet50_layout = [3, 4, 6, 3]
    return ResNet(
        Bottleneck,
        resnet50_layout,
        BatchNorm,
        output_stride,
        pretrained=pretrained,
    )


class _ASPPModule(nn.Module):
    """One atrous branch of ASPP: dilated conv (no bias) -> BatchNorm -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super().__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)

        self._init_weight()

    def forward(self, x):
        """Apply conv, normalization and ReLU in sequence."""
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm2d weight=1, bias=0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling over a 2048-channel feature map.

    Four parallel branches (one 1x1, three dilated 3x3) and an image-level
    pooling branch each produce 256 channels. They are concatenated
    (5 x 256 = 1280 channels), projected back to 256 channels with
    1x1 conv + BN + ReLU, and passed through dropout (p=0.5).

    Raises:
        NotImplementedError: if ``output_stride`` is not 16 or 8.
    """

    def __init__(self, output_stride, BatchNorm):
        super().__init__()

        in_ch = 2048
        if output_stride == 16:
            rates = [1, 6, 12, 18]
        elif output_stride == 8:
            rates = [1, 12, 24, 36]
        else:
            raise NotImplementedError
        r1, r2, r3, r4 = rates

        # 1x1 branch needs no padding; for the 3x3 branches padding == dilation
        # keeps the spatial size unchanged.
        self.aspp1 = _ASPPModule(in_ch, 256, 1, padding=0, dilation=r1, BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(in_ch, 256, 3, padding=r2, dilation=r2, BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(in_ch, 256, 3, padding=r3, dilation=r3, BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(in_ch, 256, 3, padding=r4, dilation=r4, BatchNorm=BatchNorm)

        # Image-level branch: global average pool -> 1x1 conv -> BN -> ReLU.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_ch, 256, 1, stride=1, bias=False),
            BatchNorm(256),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        """Return the fused 256-channel ASPP feature at ``x``'s resolution."""
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Broadcast the 1x1 pooled feature back to the branch resolution.
        branches.append(F.interpolate(pooled, size=branches[-1].size()[2:],
                                      mode='bilinear', align_corners=True))

        fused = self.relu(self.bn1(self.conv1(torch.cat(branches, dim=1))))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-normal for conv/linear weights; BatchNorm2d weight=1, bias=0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)


class Decoder(nn.Module):
    """DeepLabv3+-style decoder head.

    Projects a 256-channel low-level feature map to 48 channels and
    bilinearly upsamples the high-level input ``x`` to the low-level
    resolution. It then concatenates the two (``x`` must therefore carry
    304 - 48 = 256 channels), refines them with two 3x3 conv blocks, and
    predicts ``num_classes`` maps with a 1x1 conv.
    """

    def __init__(self, num_classes, BatchNorm):
        super().__init__()

        skip_in, skip_out = 256, 48

        self.conv1 = nn.Conv2d(skip_in, skip_out, 1, bias=False)
        self.bn1 = BatchNorm(skip_out)
        self.relu = nn.ReLU(inplace=True)
        # 304 input channels = 256 (high-level input) + 48 (projected skip).
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.my_classify = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
        self._init_weight()

    def forward(self, x, low_level_feat):
        """Fuse ``x`` with ``low_level_feat``; output is at the low-level resolution."""
        skip = self.relu(self.bn1(self.conv1(low_level_feat)))
        upsampled = F.interpolate(x, size=skip.size()[2:], mode='bilinear', align_corners=True)
        fused = torch.cat((upsampled, skip), dim=1)
        return self.my_classify(self.last_conv(fused))

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm2d weight=1, bias=0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)


class AFANet(nn.Module):
    """Two-task network: two-stream ResNet backbone + per-stream ASPP and decoder.

    Stream A feeds the segmentation head (output: per-pixel log-probabilities),
    stream B feeds the depth head (output: raw 1-channel map). Both outputs are
    bilinearly resized to the input's spatial size.

    Args:
        output_stride: passed to the backbone and both ASPP modules.
        freeze_bn: if True, put every BatchNorm2d into eval mode after construction.
        dataset: 'NYUv2' or 'CityScape'; selects the number of segmentation classes.

    Raises:
        ValueError: for an unsupported ``dataset``.
    """

    def __init__(self, output_stride=16, freeze_bn=False, dataset='CityScape'):
        super(AFANet, self).__init__()

        if dataset == 'NYUv2':
            self.class_nb = 13
            # NOTE(review): 'normal' is listed, but only segmentation and depth
            # heads are built below, so no normal prediction is produced.
            self.tasks = ['segmentation', 'depth', 'normal']
            self.num_out_channels = {'segmentation': 13, 'depth': 1, 'normal': 3}
        elif dataset == 'CityScape':
            self.class_nb = 7
            self.tasks = ['segmentation', 'depth']
            self.num_out_channels = {'segmentation': 7, 'depth': 1}
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError('No support {} dataset'.format(dataset))

        BatchNorm = nn.BatchNorm2d

        self.backbone = ResNet_backbone(output_stride, BatchNorm)
        self.A_aspp = ASPP(output_stride, BatchNorm)
        self.B_aspp = ASPP(output_stride, BatchNorm)
        self.A_decoder = Decoder(self.num_out_channels['segmentation'], BatchNorm)
        self.B_decoder = Decoder(self.num_out_channels['depth'], BatchNorm)

        if freeze_bn:
            # The helper is `_freeze_bn`; there is no public `freeze_bn` method.
            # NOTE: a later `.train()` call switches BN layers back to training mode.
            self._freeze_bn()

    def forward(self, input):
        """Return ``(seg_log_probs, depth)``, both resized to ``input``'s H x W."""
        x1, x2, low_level_feat1, low_level_feat2 = self.backbone(input)
        x1 = self.A_aspp(x1)
        x2 = self.B_aspp(x2)
        x1 = self.A_decoder(x1, low_level_feat1)
        x2 = self.B_decoder(x2, low_level_feat2)
        x1 = F.interpolate(x1, size=input.size()[2:], mode='bilinear', align_corners=True)
        x1 = F.log_softmax(x1, dim=1)
        x2 = F.interpolate(x2, size=input.size()[2:], mode='bilinear', align_corners=True)
        return x1, x2

    def predict(self, input):
        """Alias for ``forward``."""
        return self.forward(input)

    def _freeze_bn(self):
        """Switch every BatchNorm2d submodule to eval mode (running stats frozen)."""
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm2d):
                mod.eval()

    def get_1x_lr_params(self):
        """Yield trainable Conv2d/BatchNorm/Linear parameters of the backbone.

        NOTE(review): parameters registered directly on modules of other types
        (possibly inside the AFA layers) are not yielded by this or
        ``get_10x_lr_params`` — confirm that the optimizer still sees them.
        """
        wanted = (nn.Conv2d, nn.BatchNorm2d, nn.Linear, nn.BatchNorm1d)
        for module in (self.backbone,):
            for _, layer in module.named_modules():
                if isinstance(layer, wanted):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p

    def get_10x_lr_params(self):
        """Yield trainable Conv2d/BatchNorm2d parameters of both ASPPs and decoders."""
        wanted = (nn.Conv2d, nn.BatchNorm2d)
        for module in (self.A_aspp, self.B_aspp, self.A_decoder, self.B_decoder):
            for _, layer in module.named_modules():
                if isinstance(layer, wanted):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p